import numpy as np
import tensorflow as tf
import mat73

import sys
import os

class DataGenerator:
    """Load one numbered shard of training data and serve random mini-batches.

    Input and output arrays are read with ``np.load`` from files inside
    ``config.data_dir``. The file names are built as
    ``<prefix><file_num zero-padded to 3 digits>.npy``. For example,
    ``file_num=7`` -> ``<prefix>007.npy``, where ``<prefix>`` is
    ``config.train_input`` or ``config.train_output``.

    Attributes:
        config: The configuration object passed to the constructor.
        input: Training inputs, with Gaussian noise added when ``config.noise``
            is truthy.
        output: Training targets, row-aligned with ``input``.
        len: Number of samples (size of the first axis of ``input``).
    """

    def __init__(self, config, file_num):
        """Load training shard ``file_num`` and optionally add input noise.

        Args:
            config: Object providing ``data_dir``, ``train_input``,
                ``train_output``, ``noise`` and ``mean_noise``.
            file_num: Shard index, rendered as a 3-digit zero-padded suffix.

        Raises:
            ValueError: If the input and output arrays do not have the same
                number of samples along their first axis.
        """
        self.config = config

        suffix = str(file_num).zfill(3) + '.npy'
        train_in_file = os.path.join(self.config.data_dir, self.config.train_input + suffix)
        train_out_file = os.path.join(self.config.data_dir, self.config.train_output + suffix)

        print('*** LOADING TRAINING INPUT DATA ***')
        data_in = np.load(train_in_file)
        print('*** LOADING TRAINING OUTPUT DATA ***')
        data_out = np.load(train_out_file)

        # Fail fast: next_batch indexes both arrays with the same indices, so a
        # shorter output would raise IndexError mid-training, and a longer one
        # would leave rows silently unused.
        if data_in.shape[0] != data_out.shape[0]:
            raise ValueError(
                'training input and output sample counts differ: '
                '{} in {!r} vs {} in {!r}'.format(
                    data_in.shape[0], train_in_file,
                    data_out.shape[0], train_out_file))

        if self.config.noise:
            # Zero-mean Gaussian noise. Despite its name, config.mean_noise is
            # passed as the standard deviation (scale), not the mean.
            noise = np.random.normal(0, self.config.mean_noise, np.shape(data_in))
            data_in = data_in + noise

        self.input = data_in
        self.output = data_out
        self.len = self.input.shape[0]

    def next_batch(self, batch_size):
        """Yield a single random mini-batch ``(inputs, outputs)``.

        Indices are drawn uniformly *with replacement* (the
        ``np.random.choice`` default), so a batch may contain repeated samples
        and ``batch_size`` may exceed ``self.len``. This is a generator that
        yields exactly once; retrieve the batch with e.g.
        ``next(gen.next_batch(n))``.
        """
        idx = np.random.choice(self.len, batch_size)
        yield self.input[idx], self.output[idx]

class ValDataGenerator:
    """Load the validation/test arrays and serve random mini-batches.

    Input and output arrays are read with ``np.load`` from
    ``config.data_dir`` joined with ``config.test_input`` and
    ``config.test_output``, which are used as complete file names.

    Attributes:
        config: The configuration object passed to the constructor.
        input: Test inputs, with Gaussian noise added when ``config.noise`` is
            truthy.
        output: Test targets, row-aligned with ``input``.
        len: Number of samples (size of the first axis of ``input``).
    """

    def __init__(self, config):
        """Load the test input/output arrays and optionally add input noise.

        Args:
            config: Object providing ``data_dir``, ``test_input``,
                ``test_output``, ``noise`` and ``mean_noise``.

        Raises:
            ValueError: If the input and output arrays do not have the same
                number of samples along their first axis.
        """
        self.config = config

        test_in_file = os.path.join(self.config.data_dir, self.config.test_input)
        test_out_file = os.path.join(self.config.data_dir, self.config.test_output)

        print('*** LOADING TESTING INPUT DATA ***')
        data_in = np.load(test_in_file)

        print('*** LOADING TESTING OUTPUT DATA ***')
        data_out = np.load(test_out_file)

        # Fail fast: next_batch indexes both arrays with the same indices, so a
        # shorter output would raise IndexError during evaluation, and a longer
        # one would leave rows silently unused.
        if data_in.shape[0] != data_out.shape[0]:
            raise ValueError(
                'testing input and output sample counts differ: '
                '{} in {!r} vs {} in {!r}'.format(
                    data_in.shape[0], test_in_file,
                    data_out.shape[0], test_out_file))

        if self.config.noise:
            # Zero-mean Gaussian noise. Despite its name, config.mean_noise is
            # passed as the standard deviation (scale), not the mean.
            noise = np.random.normal(0, self.config.mean_noise, np.shape(data_in))
            data_in = data_in + noise

        self.input = data_in
        self.output = data_out
        self.len = self.input.shape[0]

    def next_batch(self, batch_size):
        """Yield a single random mini-batch ``(inputs, outputs)``.

        Indices are drawn uniformly *with replacement* (the
        ``np.random.choice`` default), so a batch may contain repeated samples
        and ``batch_size`` may exceed ``self.len``. This is a generator that
        yields exactly once; retrieve the batch with e.g.
        ``next(gen.next_batch(n))``.
        """
        idx = np.random.choice(self.len, batch_size)
        yield self.input[idx], self.output[idx]